#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2011 - 2013 Stefano Mazzucco <stefano -at- curso.re>
# All rights reserved.
#
# This file is part of Crystal Ball Plus.
#
# Crystal Ball Plus is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Crystal Ball Plus is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Crystal Ball Plus.  If not, see <http://www.gnu.org/licenses/>.
#

import os
import sys
import numpy as np
# import cStringIO
from nose import with_setup

# import nose
# sys.stderr.write('%s version: %s\n' % (nose.__name__, nose.__version__))

from crystalballplus.structures import *
from crystalballplus.compare import *
from crystalballplus.operations import *
from crystalballplus.report import *

# Unit-cell parameter tuples (angles in degrees), keyed by a descriptive
# name; populated by setup_uc_deg().
uc_deg = {}

# Diffractogram paths checked by test_read_dfg; populated by setup_files().
files = []

# Shared comparison results, populated by setup_compare():
#   D <- compare_d('tests/inp/', 'tests/ref/')
#   R <- group_reflections(D)
#   A <- compare_angles(R)
# Z is declared alongside them but appears unused in this module.
D, R, Z, A = {}, {}, {}, {}

def setup_uc_deg():
    """Fill the module-level ``uc_deg`` with sample unit cells.

    Each value holds three cell lengths followed by three angles, the
    angles expressed in degrees.  Entries are inserted in a fixed order.
    """
    table = [
        ('triclinic (V4O7)', (5.504, 7.007, 19.243, 41.3, 72.5, 109.4)),
        ('monoclinic (AsSe)', (6.74, 13.80, 10.01, 90.0, 113.79, 90.0)),
        ('hexagonal (Cs7O)', (16.2440, 16.2440, 9.145, 90., 90., 120.)),
        ('rombohedral (Ca2N)', (6.603, 6.603, 6.603, 31.98, 31.98, 31.98)),
        ('rombohedral (TiF3)', (3.872, 3.872, 3.872, 89., 89., 89.)),
        ('orthorombic (SbO4)', (5.456, 4.814, 11.787, 90., 90., 90.)),
        ('tetragonal (BaGa4)', (4.56, 4.56, 10.77, 90., 90., 90.)),
        ('cubic (AuZn)', (3.19, 3.19, 3.19, 90., 90., 90.)),
        ('cubic (K)', (5.225, 5.225, 5.225, 90.0, 90.0, 90.0)),
    ]
    for label, cell in table:
        uc_deg[label] = cell

# Unit-cell parameter tuples (angles in radians); populated by setup_uc_rad().
uc_rad = {}

def setup_uc_rad():
    """Fill the module-level ``uc_rad`` with radian-angle unit cells.

    Keys match those used by setup_uc_deg, and the values are the radian
    counterparts of those cells: three lengths, then three angles in radians.
    """
    right = np.pi / 2.
    table = [
        ('triclinic (V4O7)', (5.504, 7.007, 19.243, 0.72082, 1.26536, 1.9094)),
        ('monoclinic (AsSe)', (6.74, 13.80, 10.01, right, 1.9860, right)),
        ('hexagonal (Cs7O)', (16.2440, 16.2440, 9.145, right, right, 2. / 3. * np.pi)),
        ('rombohedral (Ca2N)', (6.603, 6.603, 6.603, 0.55816, 0.55816, 0.55816)),
        ('rombohedral (TiF3)', (3.872, 3.872, 3.872, 1.55334, 1.55334, 1.55334)),
        ('orthorombic (SbO4)', (5.456, 4.814, 11.787, right, right, right)),
        ('tetragonal (BaGa4)', (4.56, 4.56, 10.77, right, right, right)),
        ('cubic (AuZn)', (3.19, 3.19, 3.19, right, right, right)),
        ('cubic (K)', (5.225, 5.225, 5.225, right, right, right)),
    ]
    for label, cell in table:
        uc_rad[label] = cell

def check_dmt(uc_dict, deg):
    """Yield, for each unit cell in ``uc_dict``, whether its direct metric
    tensor is exactly symmetric.

    :param uc_dict: mapping whose values are unit-cell parameter tuples as
        built by setup_uc_deg / setup_uc_rad (three lengths, three angles).
    :param deg: passed straight through as ``deg=`` to
        ``direct_metric_tensor``; True when the angles are in degrees.
    :returns: a generator of booleans, one per unit cell.
    """
    # .values() instead of the Python 2-only .itervalues(), so the
    # check runs under both Python 2 and Python 3.
    for uc in uc_dict.values():
        dmt = direct_metric_tensor(uc, deg=deg)
        yield np.all(dmt == dmt.T)

def check_inversion(uc_dict, deg):
    """Yield, for each unit cell in ``uc_dict``, whether the product of its
    direct and reciprocal metric tensors is numerically the 3x3 identity.

    :param uc_dict: mapping whose values are unit-cell parameter tuples as
        built by setup_uc_deg / setup_uc_rad (three lengths, three angles).
    :param deg: passed straight through as ``deg=`` to both tensor
        functions; True when the angles are in degrees.
    :returns: a generator of booleans (``np.allclose`` results), one per cell.
    """
    # The target matrix is the same for every cell; build it once.
    identity = np.eye(3)
    # .values() instead of the Python 2-only .itervalues(), so the
    # check runs under both Python 2 and Python 3.
    for uc in uc_dict.values():
        dmt = direct_metric_tensor(uc, deg=deg)
        rmt = reciprocal_metric_tensor(uc, deg=deg)
        yield np.allclose(np.dot(dmt, rmt), identity)

def setup_files():
    """Append the input and reference diffractogram paths to ``files``."""
    files.extend([
        'tests/inp/inp_Si.dfg',
        'tests/inp/inp_SiO2.dfg',
        'tests/ref/ref_Si.dfg',
        'tests/ref/ref_SiO2.dfg',
    ])

def teardown_files():
    """Empty the module-level ``files`` list after test_read_dfg."""
    # Clear in place: assigning ``files = []`` here would only bind a local
    # name and leave the module-level list (and its entries) untouched.
    del files[:]

def teardown_dict():
    """Empty the shared result dicts (D, R, Z, A) after a comparison test."""
    # Clear in place: rebinding the names (``D, R, A = {}, {}, {}``) would
    # only create locals and leave the module-level dicts unchanged.
    for results in (D, R, Z, A):
        results.clear()

def setup_compare():
    """Fill D, R and A by comparing the input and reference directories.

    Each stage feeds the next: D from compare_d, R from the reflections
    grouped out of D, and A from the angles compared over R.
    """
    inp_dir, ref_dir = 'tests/inp/', 'tests/ref/'
    D.update(compare_d(inp_dir, ref_dir))
    R.update(group_reflections(D))
    A.update(compare_angles(R))

@with_setup(setup_uc_deg)
def test_uc_deg():
    """Direct metric tensors of the degree-angle cells are symmetric."""
    symmetric = all(check_dmt(uc_deg, True))
    assert symmetric

@with_setup(setup_uc_deg)
def test_inversion_deg():
    """Direct and reciprocal metric tensors are inverses (degree angles)."""
    inverses = all(check_inversion(uc_deg, True))
    assert inverses

@with_setup(setup_uc_rad)
def test_uc_rad():
    """Direct metric tensors of the radian-angle cells are symmetric."""
    symmetric = all(check_dmt(uc_rad, False))
    assert symmetric

@with_setup(setup_uc_rad)
def test_inversion_rad():
    """Direct and reciprocal metric tensors are inverses (radian angles)."""
    inverses = all(check_inversion(uc_rad, False))
    assert inverses

@with_setup(setup_files, teardown_files)
def test_read_dfg():
    """Every listed diffractogram file loads into a truthy Diffractogram."""
    for path in files:
        loaded = Diffractogram(path)
        assert loaded

@with_setup(setup_compare, teardown_dict)
def test_compare():
    """setup_compare leaves D, R and A all non-empty."""
    for results in (D, R, A):
        assert results
    
@with_setup(setup_compare, teardown_dict)
def test_D():
    """Check the per-(input, reference) entries that compare_d put in D."""
    from_si = D['inp_Si.dfg']
    assert len(from_si['ref_Si.dfg'][0]) == 3
    assert len(from_si['ref_SiO2.dfg'][0]) == 1
    from_sio2 = D['inp_SiO2.dfg']
    assert len(from_sio2['ref_SiO2.dfg'][0]) == 3
    assert from_sio2['ref_Si.dfg'] is None
    
def _check_zone_axes_in_family(result, hkl):
    """Assert that every zone axis stored in one A[inp][ref] entry matches
    some member of ``result[1].family(hkl)``.

    ``result[0]`` is iterated as a mapping whose values carry a
    ``'zone_axis'`` sequence of arrays; ``result[1]`` must provide
    ``family(hkl)``, whose members are compared element-wise with each axis.
    (Presumably the family is the set of symmetry-equivalent directions of
    ``hkl`` -- confirm against crystalballplus.)
    """
    # The family depends only on the reference entry and hkl, so build it
    # once rather than once per zone axis; list() also makes it safe to
    # iterate repeatedly should family() return a one-shot iterator.
    family = list(result[1].family(hkl))
    for key, angles in result[0].items():
        for axis in angles['zone_axis']:
            # Require a match per axis.  Comparing a total match count with
            # the number of axes can be fooled by one axis matching twice
            # while another matches nothing (or by duplicates in the family).
            assert any(np.all(member == axis) for member in family), \
                'zone axis %s (key %r) not in family of %s' % (axis, key, hkl)


@with_setup(setup_compare, teardown_dict)
def test_A():
    """Check the angle comparisons in A and that their zone axes belong to
    the expected families."""
    assert len(A['inp_Si.dfg']['ref_Si.dfg'][0]) == 24
    assert A['inp_Si.dfg']['ref_SiO2.dfg'] is None
    assert len(A['inp_SiO2.dfg']['ref_SiO2.dfg'][0]) == 24
    assert A['inp_SiO2.dfg']['ref_Si.dfg'] is None

    _check_zone_axes_in_family(A['inp_SiO2.dfg']['ref_SiO2.dfg'], (1, 0, 1))
    _check_zone_axes_in_family(A['inp_Si.dfg']['ref_Si.dfg'], (2, 0, 1))
